[PyTorch] Bunch of fixes for cpu offloading #2535
Conversation
```python
# Only offload tensors with at least 256k elements (~1MB for float32)
if t.numel() < 256 * 1024:
    return False
```
If I understand correctly, this is the reason we need to expose an option to disable bulk allocation in split_quantize? Bulk-allocated tensors hold on to memory until all of them are deallocated, so this condition means that some small tensor might keep a large memory block alive.
Yes. And we cannot offload the small tensors either, because doing so forces synchronization between compute and communication operations when CUDA_DEVICE_MAX_CONNECTIONS=1 is set, which is required for comm/GEMM overlap.
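To make the failure mode concrete, here is a minimal sketch (plain PyTorch, not TE code, assuming a CUDA device is available) of how a tiny view over a bulk allocation can keep the whole block resident:

```python
import torch

# One bulk allocation that backs both a large component and a tiny scale as views.
block = torch.empty(64 * 1024 * 1024, device="cuda")  # ~256 MB of float32
data = block[: block.numel() - 1024]                   # large component (offloadable)
scale = block[-1024:]                                  # tiny component (kept on GPU)

# Both views share the same underlying storage.
assert data.untyped_storage().data_ptr() == scale.untyped_storage().data_ptr()

del data  # the large view goes away (e.g. after being offloaded to CPU)...
# ...but the full ~256 MB storage stays allocated because `scale` still references it.
print(torch.cuda.memory_allocated())  # still reports the whole block
```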
/te-ci pytorch
timmoon10 left a comment:
LGTM
Greptile Summary
This PR addresses CPU offloading performance and compatibility issues through multiple coordinated changes:
Key Changes
- CPU Overhead Reduction: Added layer-level offloading skipping in DefaultOffloadSynchronizer.push_tensor() to avoid processing tensors when a layer won't be offloaded. Also conditionally guards mark_not_offload() calls.
- QuantizedTensor Offloading Support: Extended CPU offloading to handle QuantizedTensor types by decomposing them into component tensors, offloading each component recursively, and reconstructing them during reload (see the sketch after this list).
- DTensor Compatibility: Changed from torch.empty(shape, device) to torch.empty_like(tensor, dtype) in FusedAdam to properly respect DTensor sharding annotations.
- Small Tensor Offloading Threshold: Added a 256K-element minimum threshold to prevent offloading of tiny tensors that would cause synchronization issues with CUDA_DEVICE_MAX_CONNECTIONS=1.
- Bulk Allocation Control: Added a disable_bulk_allocation parameter to the split_quantize() C++ function, enabled when CPU offloading is active to avoid grouping small tensors with large ones.
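For illustration, a rough sketch of the decompose/offload/reload flow for QuantizedTensor is below. The helper names push_component and pop_component are hypothetical; prepare_for_saving() and restore_from_saved() are the hooks referenced in the sequence diagram further down, and the real logic lives in transformer_engine/pytorch/cpu_offload.py.

```python
def push_quantized(qt, push_component):
    """Split a quantized tensor into plain-tensor components and push each one."""
    tensors, tensor_objs = qt.detach().prepare_for_saving()
    handles = [push_component(t) if t is not None else None for t in tensors]
    return handles, tensor_objs

def pop_quantized(handles, tensor_objs, pop_component):
    """Pop each component back to GPU memory and rebuild the quantized tensor."""
    tensors = [pop_component(h) if h is not None else None for h in handles]
    tensor_objs[0].restore_from_saved(tensors)  # assumes a single saved quantized tensor
    return tensor_objs[0]
```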
Files Modified
- transformer_engine/pytorch/cpu_offload.py: Core offloading logic with QuantizedTensor support
- transformer_engine/pytorch/optimizers/fused_adam.py: DTensor-aware state initialization
- transformer_engine/pytorch/module/linear.py: Conditional mark_not_offload() guarding
- transformer_engine/pytorch/module/grouped_linear.py: disable_bulk_allocation parameter passing
- transformer_engine/pytorch/quantized_tensor.py: Removed CPU operation validation that blocked offloading
- C++ files: Added disable_bulk_allocation parameter and logic
- Tests: Updated tensor sizes to ensure components exceed the 256K threshold
Issues Found
Critical Issue: The FusedAdam DTensor fix is incomplete. When a QuantizedTensor parameter wraps a DTensor, calling dequantize() creates a new plain tensor that loses DTensor sharding metadata. The fix should use the original parameter directly with .empty_like().
Type Annotation Issue: DefaultOffloadSynchronizer.push_tensor() return type annotation doesn't reflect actual return type (missing tuple[list, list]).
Behavioral Change: Default for retain_pinned_cpu_buffers changed from False to True, affecting memory usage patterns and performance characteristics. This change is not documented in the PR description.
Confidence Score: 2/5
- This PR has a critical bug that breaks DTensor parameter handling in FusedAdam, and incomplete type annotations. The DTensor fix is fundamentally broken for QuantizedTensor parameters.
- The PR contains one critical logic bug that makes the DTensor fix incomplete/incorrect. The FusedAdam change dequantizes QuantizedTensor parameters, which destroys DTensor sharding information that the empty_like() call is meant to preserve. Additionally, return type annotations are incomplete, and an undocumented behavioral default change (retain_pinned_cpu_buffers) could affect existing users. While the core CPU offloading improvements are sound, these issues need resolution before merging.
- transformer_engine/pytorch/optimizers/fused_adam.py (critical DTensor bug), transformer_engine/pytorch/cpu_offload.py (type annotation and default value)
Important Files Changed
File Analysis
| Filename | Score | Overview |
|---|---|---|
| transformer_engine/pytorch/optimizers/fused_adam.py | 2/5 | FusedAdam state initialization broken for QuantizedTensor parameters with DTensor sharding. Calling dequantize() loses DTensor metadata that should be preserved with empty_like(). |
| transformer_engine/pytorch/cpu_offload.py | 3/5 | Multiple changes: QuantizedTensor offloading support added, CPU overhead reduced with layer skipping optimization, but return type annotation mismatch and behavioral default change. Default retain_pinned_cpu_buffers changed from False to True. |
| transformer_engine/pytorch/quantized_tensor.py | 4/5 | Removed CPU operation validation checks. This change is safe as it enables QuantizedTensor component tensors to be offloaded to CPU while preserving the original tensor type checks elsewhere. |
| transformer_engine/pytorch/module/grouped_linear.py | 5/5 | Disables bulk allocation in split_quantize when CPU offloading is active, addressing synchronization issues with small tensors and CUDA_DEVICE_MAX_CONNECTIONS=1. |
| transformer_engine/pytorch/csrc/extensions/cast.cpp | 5/5 | Correctly implements disable_bulk_allocation feature by wrapping bulk allocation logic in conditional, allowing fallback to unfused allocation when needed. |
Sequence Diagram
```mermaid
sequenceDiagram
participant User as User Code
participant Linear as Linear Module
participant CPU_Offload as CPU Offload System
participant QuantTensor as QuantizedTensor
participant OffloadState as OffloadableLayerState
User->>Linear: Forward pass with QuantizedTensor
Linear->>OffloadState: push_tensor(quantized_tensor)
alt QuantizedTensor Case
OffloadState->>QuantTensor: detach()
QuantTensor->>QuantTensor: prepare_for_saving()
Note over QuantTensor: Decompose into [data, transpose, scale_inv]
loop For each component
OffloadState->>OffloadState: push_tensor(component)
OffloadState->>OffloadState: Check if numel() > 256K
end
OffloadState-->>Linear: Return (push_results, tensor_objs)
else Regular Tensor Case
OffloadState->>OffloadState: Check if offload needed
OffloadState-->>Linear: Return tensor_id or tensor
end
Note over OffloadState: Backward pass
Linear->>OffloadState: pop_tensor(tensor_id_or_tuple)
alt QuantizedTensor Tuple Case
OffloadState->>OffloadState: pop_tensor(each_component)
OffloadState->>QuantTensor: restore_from_saved(components)
OffloadState-->>Linear: Return reconstructed QuantizedTensor
else Regular Case
OffloadState-->>Linear: Return tensor
end
```
```diff
     dtype = self.name_to_dtype_map[state_name]
+    # Handle QuantizedTensor by dequantizing first
+    param_for_empty = param.dequantize() if isinstance(param, QuantizedTensor) else param
     if store_param_remainders:
-        data = torch.zeros(param.shape, dtype=torch.int16, device=param.device)
+        data = torch.zeros_like(param_for_empty, dtype=torch.int16)
     else:
-        data = torch.empty(param.shape, dtype=dtype, device=param.device)
+        data = torch.empty_like(param_for_empty, dtype=dtype)
```
The dequantization of QuantizedTensor parameters here breaks DTensor sharding preservation. When param is a QuantizedTensor wrapping a DTensor, calling dequantize() creates a new plain tensor that loses the DTensor sharding metadata. This defeats the purpose of using .empty_like() to preserve DTensor sharding.
The fix should use the original parameter directly without dequantization, since .empty_like() respects the sharding annotations of the input tensor regardless of whether it's quantized:
Suggested change:
```diff
     dtype = self.name_to_dtype_map[state_name]
-    # Handle QuantizedTensor by dequantizing first
-    param_for_empty = param.dequantize() if isinstance(param, QuantizedTensor) else param
     if store_param_remainders:
-        data = torch.zeros_like(param_for_empty, dtype=torch.int16)
+        data = torch.zeros_like(param, dtype=torch.int16)
     else:
-        data = torch.empty_like(param_for_empty, dtype=dtype)
+        data = torch.empty_like(param, dtype=dtype)
```
Alternatively, if dequantization is necessary for some reason, the sharding information from the original parameter should be explicitly preserved.
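As a side note, here is a minimal sketch (not part of this PR; assumes a recent PyTorch with the public torch.distributed.tensor API, an initialized process group, and a CUDA device) of why empty_like preserves DTensor sharding while empty with an explicit shape does not:

```python
import torch
import torch.distributed as dist
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor import distribute_tensor, Shard

mesh = init_device_mesh("cuda", (dist.get_world_size(),))
full = torch.randn(4096, 1024, device="cuda")
param = distribute_tensor(full, mesh, placements=[Shard(0)])  # DTensor, sharded on dim 0

# empty_like dispatches through DTensor, so the optimizer state keeps the same
# placements and only the local shard is allocated on each rank.
state = torch.empty_like(param, dtype=torch.float32)

# empty with param.shape allocates a plain, full-size tensor on every rank,
# which is what caused the out-of-memory errors this PR targets.
bad_state = torch.empty(param.shape, dtype=torch.float32, device="cuda")
```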
```python
def push_tensor(self, tensor: torch.Tensor) -> int | torch.Tensor:
    """Push tensor - skip processing if layer won't be offloaded to reduce CPU overhead."""
    if not self.offload_layer_map.get(self.num_of_fwds, False):
        return tensor
    return self.layer_states[self.num_of_fwds].push_tensor(tensor)
```
The return type annotation for DefaultOffloadSynchronizer.push_tensor() is incomplete and doesn't match the base class signature. This method can return a tuple (when processing QuantizedTensors), but the annotation only specifies int | torch.Tensor.
Suggested change:
```diff
-def push_tensor(self, tensor: torch.Tensor) -> int | torch.Tensor:
+def push_tensor(self, tensor: torch.Tensor) -> int | torch.Tensor | tuple[list, list]:
     """Push tensor - skip processing if layer won't be offloaded to reduce CPU overhead."""
     if not self.offload_layer_map.get(self.num_of_fwds, False):
         return tensor
     return self.layer_states[self.num_of_fwds].push_tensor(tensor)
```
```python
num_layers: Optional[int] = 1,
model_layers: int = 1,
offload_activations: bool = True,
offload_weights: bool = False,
```
The default value for retain_pinned_cpu_buffers has been changed from False to True. While this may improve performance by reusing CPU buffers across iterations, this is a significant behavioral change that affects memory usage patterns and is not mentioned in the PR description. Consider documenting this change in the commit message and PR description, as it could impact existing users' performance characteristics.
If backward compatibility is important, consider keeping the default as False or provide a migration path for existing code.
Description
This PR addresses several issues related to CPU offloading performance and compatibility.
1. CPU Overhead Reduction
This PR reduces CPU overhead through multiple optimizations:
- __torch_function__ hook: previously costly validation checks have been eliminated.

2. Out of Memory Error with Fused Optimizer and DTensor
PyTorch introduced the JAX-like DTensor, and some workloads use our fused optimizer with this tensor type. The previous implementation used .empty with an explicit shape, which works correctly for standard tensors but does not respect sharding for DTensor, resulting in full-size tensors being created on each device. This has been fixed by switching to .empty_like, which preserves the sharding of the parameter.

3. Synchronization Issues When Offloading Small Tensors
For grouped tensors, allocation is performed in bulk, requiring an all-or-nothing offloading approach. This meant small tensors such as scales were also offloaded, which caused issues with comm/GEMM overlap when CUDA_DEVICE_MAX_CONNECTIONS=1 was set: the tensors were small enough that SMs were used for copying instead of the copy engines, leading to synchronization problems. The fixes are the 256K-element offload threshold and the disable_bulk_allocation option for split_quantize; a sketch of how they fit together follows below.
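The sketch below uses hypothetical helpers (not the actual TE API): a size threshold decides which tensors are offloaded, and bulk allocation is turned off while offloading so that small, non-offloaded tensors do not share storage with large offloaded ones.

```python
import torch

MIN_OFFLOAD_NUMEL = 256 * 1024  # ~1 MB for float32, matching the threshold in this PR

def should_offload(t: torch.Tensor) -> bool:
    """Offload only GPU tensors large enough to be copied by the copy engines."""
    return t.is_cuda and t.numel() >= MIN_OFFLOAD_NUMEL

def use_bulk_allocation(cpu_offloading_enabled: bool) -> bool:
    """Keep the fused bulk allocation unless CPU offloading could split the group."""
    return not cpu_offloading_enabled
```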